import os
import logging
from typing import Optional

from gpustack.client import ClientSet
from gpustack.client.worker_manager_clients import (
    WorkerStatusClient,
    WorkerRegistrationClient,
)
from gpustack.config.config import Config
from gpustack.schemas.workers import (
    WorkerCreate,
    WorkerUpdate,
)
from gpustack.schemas.users import (
    UserPublic,
)
from gpustack.security import API_KEY_PREFIX
from gpustack.utils import platform
from gpustack.worker.collector import WorkerStatusCollector
from gpustack.config.registration import (
    registration_client,
    read_worker_token,
    write_worker_token,
)
from gpustack.utils.profiling import time_decorator

logger = logging.getLogger(__name__)


class WorkerManager:
    """
    Manage this worker's registration with the server and periodic status sync.

    On construction, a worker token persisted in ``cfg.data_dir`` (read via
    ``read_worker_token``) is reused to build the API clients. When no token
    exists, ``register_with_server`` must be called to obtain one before
    ``sync_worker_status`` does any work.
    """

    _is_embedded: bool
    _collector: WorkerStatusCollector
    _clientset: Optional[ClientSet] = None
    _registration_client: WorkerRegistrationClient
    # Defaults to None so sync_worker_status can detect "no clients yet"
    # instead of raising AttributeError when no worker token was stored.
    _status_client: Optional[WorkerStatusClient] = None
    # worker name is used for registration, labels and logging
    _worker_name: str

    def __init__(
        self,
        cfg: Config,
        is_embedded: bool,
        collector: WorkerStatusCollector,
        worker_name: str,
    ):
        """
        Args:
            cfg: Worker configuration; supplies ``data_dir``, the server URL
                and the registration token.
            is_embedded: Forwarded to the registration client as
                ``wait_token_file``.
            collector: Produces the worker status sent to the server.
            worker_name: Name used for registration, the ``worker-name``
                label and log messages.
        """
        self._is_embedded = is_embedded
        self._cfg = cfg
        self._collector = collector
        self._worker_name = worker_name
        # Explicitly start without clients; they exist only once a token is known.
        self._clientset = None
        self._status_client = None
        # Reuse the credentials of a previous registration if they were persisted.
        worker_token = read_worker_token(self._cfg.data_dir)
        if worker_token:
            self._prepare_clients(worker_token)

    def _prepare_clients(self, token: str):
        """Build the ClientSet and status client authenticated with ``token``."""
        self._clientset = ClientSet(
            base_url=self._cfg.get_server_url(),
            api_key=token,
        )
        self._status_client = WorkerStatusClient(self._clientset.http_client)

    def sync_worker_status(self):
        """
        Collect the current worker status and push it to the server.

        Intended to be called periodically. Does nothing until the status
        client exists (i.e. a worker token was loaded at start-up or obtained
        via ``register_with_server``). It does not register the worker itself.
        Collection and upload failures are logged and not raised.
        """
        if self._status_client is None:
            return
        try:
            workerStatus = self._collector.collect(self._clientset)
        except Exception as e:
            logger.error(f"Failed to collect status for worker: {e}")
            return
        try:
            self._status_client.create(workerStatus)
        except Exception as e:
            logger.error(f"Failed to update worker status: {e}")

    def register_with_server(self) -> ClientSet:
        """
        Ensure the worker is registered and return an authenticated ClientSet.

        If a clientset was built from a stored worker token (see ``__init__``),
        it is validated against ``/users/me``. It is reused when the server
        answers 200 with a user bound to a worker (``worker_id`` set).
        Otherwise the worker re-registers, persists the new token and rebuilds
        the clients.

        Returns:
            The ClientSet authenticated with a valid worker token.

        Raises:
            Exception: Errors from the validation request propagate unchanged;
                registration errors are logged and then re-raised.
        """
        if self._clientset:
            response = self._clientset.http_client.get_httpx_client().get(
                "/users/me"
            )
            if response.status_code == 200:
                user = UserPublic.model_validate(response.json())
                if user.worker_id is not None:
                    return self._clientset
            logger.warning(
                "Existing worker_token is invalid, trying to re-register."
            )

        try:
            token = self._register_worker()
            write_worker_token(self._cfg.data_dir, token)
            self._prepare_clients(token)
        except Exception as e:
            logger.error(f"Failed to register worker: {e}")
            raise
        return self._clientset

    def _register_worker(self) -> str:
        """
        Register this worker with the server and return the issued worker token.

        Reloads the registration token from config, builds the registration
        client, collects an initial status, and submits it merged with the
        worker name, builtin labels and an optional ``external_id`` read from
        ``<data_dir>/external_id``.
        """
        logger.info(
            f"Registering worker: {self._worker_name}",
        )
        # always reloads the token
        self._cfg.reload_token()
        self._registration_client = registration_client(
            data_dir=self._cfg.data_dir,
            server_url=self._cfg.get_server_url(),
            registration_token=self._cfg.token,
            wait_token_file=self._is_embedded,
        )
        external_id = None
        external_id_path = os.path.join(self._cfg.data_dir, 'external_id')
        if os.path.exists(external_id_path):
            # NOTE(review): content is used verbatim (no strip); a trailing
            # newline in the file would become part of the id — confirm intent.
            with open(external_id_path, 'r') as f:
                external_id = f.read()

        @time_decorator
        def timed_collect():
            return self._collector.collect(initial=True)

        workerStatus = timed_collect()
        workerUpdate = WorkerUpdate(
            name=self._worker_name,
            labels=self._ensure_builtin_labels(),
        )
        # Later keys win: name/labels from workerUpdate override any collected
        # values, and external_id is always set explicitly.
        to_register = WorkerCreate.model_validate(
            {
                **workerStatus.model_dump(),
                **workerUpdate.model_dump(),
                "external_id": external_id,
            }
        )
        created = self._registration_client.create(to_register)
        logger.info(
            f"Worker {self._worker_name} registered with worker_id {created.id}."
        )
        return created.token

    def _register_shutdown_hooks(self):
        """Placeholder; currently registers nothing."""
        pass

    def _ensure_builtin_labels(self) -> dict:
        """Return the labels every worker reports: os, arch and worker-name."""
        labels = {
            "os": platform.system(),
            "arch": platform.arch(),
            "worker-name": self._worker_name,
        }

        # Legacy workers with version 0.7.x send worker_uuid as part of registration.
        # Legacy workers with version <0.7.x don't have worker_uuid, so we use this label as part of the registration allowance.
        if self._cfg.token and not self._cfg.token.startswith(API_KEY_PREFIX):
            labels["gpustack.existence-check"] = "true"
        return labels
